#!/usr/bin/env python3

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# This tool is based on the Spark merge_spark_pr script:
# https://github.com/apache/spark/blob/master/dev/merge_spark_pr.py

from collections import defaultdict, Counter

import jira
import re
import sys

# Jira project key; interpolated into the JQL queries this script sends to Jira.
PROJECT = "AIRFLOW"


try:
    import click
except ImportError:
    print("Could not find the click library. Run 'sudo pip install click' to install.")
    sys.exit(-1)

try:
    import git
except ImportError:
    print("Could not import git. Run 'sudo pip install gitpython' to install")
    sys.exit(-1)

# Issue-browse URL prefix of the ASF Jira (not referenced in the visible part
# of this file).
JIRA_BASE = "https://issues.apache.org/jira/browse"
# Server root handed to the jira client.
JIRA_API_BASE = "https://issues.apache.org/jira"

# Dictionary keys given to the fields of one parsed `git log` record, in the
# same order as the placeholders of GIT_LOG_FORMAT.
GIT_COMMIT_FIELDS = [
    'id',
    'author_name',
    'author_email',
    'date',
    'subject',
    'body',
]
# `git log --format` string: abbreviated hash, author name, author email,
# author date, subject and body. Fields are separated by the 0x1f byte and
# every record is terminated by the 0x1e byte, so that free-form commit text
# can be split unambiguously.
GIT_LOG_FORMAT = '%x1f'.join(('%h', '%an', '%ae', '%ad', '%s', '%b')) + '%x1e'

# Colour used when printing an issue status; statuses that are not listed are
# printed without colour.
STATUS_COLUR_MAP = {
    'Resolved': 'green',
    'Open': 'red',
    'Closed': 'yellow',
}


def get_jiras_for_version(version):
    """Yield the Jira issues of PROJECT whose fixVersion is ``version``.

    Issues resolved as Duplicate or "Cannot Reproduce" are excluded and the
    results are ordered by most recent update. The search is paged 50 issues
    at a time; a page shorter than that ends the iteration.

    :param version: fixVersion value interpolated into the JQL query.
    """
    client = jira.client.JIRA({'server': JIRA_API_BASE})
    page_size = 50

    # The query does not depend on the page, so build it once up front.
    jql = '''PROJECT={} and fixVersion={} AND
               (Resolution IS NULL OR Resolution NOT IN (Duplicate, "Cannot Reproduce"))
               ORDER BY updated DESC
            '''.format(PROJECT, version)

    offset = 0
    while True:
        page = client.search_issues(jql, maxResults=page_size, startAt=offset)
        yield from page

        # A short page means there is nothing left to fetch.
        if len(page) < page_size:
            return
        offset += page_size


# Finds Jira references such as "[AIRFLOW-1234]", "AIRFLOW-1234:" or
# "AIRFLOW 1234 " (any letter case); group 1 is the reference as written.
# Verbose mode ignores the spaces that lay out the pattern - only the one
# inside the "[- ]" character class is significant.
issue_re = re.compile(
    r".*? (AIRFLOW[- ][0-9]{1,6}) (?: \]|\s|: )",
    flags=re.VERBOSE | re.IGNORECASE,
)
# Captures just the numeric part of a reference at the start of a string.
issue_id_re = re.compile(
    r"AIRFLOW[- ]([0-9]{1,6})",
    flags=re.IGNORECASE,
)
# "Closes #1234" anywhere in a (multi-line) commit body; group 2 is "#1234".
pr_re = re.compile(
    r"(.*)Closes (#[0-9]{1,6})",
    flags=re.DOTALL,
)
# "(#1234)" at the very end of a commit subject; group 1 is "#1234".
pr_title_re = re.compile(r".*\((#[0-9]{1,6})\)$")

def _process_git_log_item(log_item):
    """Store the pull request reference of a commit under ``'pull_request'``.

    The reference is taken from the commit subject when ``pr_title_re``
    matches it, otherwise from the commit body (if the record has one) when
    ``pr_re`` matches it. ``'-'`` is stored when neither yields a reference.

    :param log_item: commit dict with at least a ``'subject'`` key; it is
        modified in place.
    """
    pull_request = '-'

    subject_match = pr_title_re.match(log_item['subject'])
    if subject_match:
        pull_request = subject_match.group(1)
    elif 'body' in log_item:
        body_match = pr_re.match(log_item['body'])
        if body_match:
            pull_request = body_match.group(2)

    log_item['pull_request'] = pull_request

def normalize_issue_id(issue_id):
    """Return the ``AIRFLOW-<number>`` form of an issue reference.

    Only the digits captured by ``issue_id_re`` are kept, so the letter case
    and the separator used in ``issue_id`` do not matter.
    """
    number = issue_id_re.match(issue_id)[1]
    return "AIRFLOW-{}".format(number)

def get_merged_issues(repo, version, previous_version=None):
    """Group the commits on the current branch by the Jira issues they mention.

    :param repo: repository object; ``repo.git.log`` is called on it.
    :param version: not used by this function; kept so existing callers keep
        working.
    :param previous_version: optional revision. When given, only the commits
        in ``previous_version..`` are read instead of the whole history.
    :return: ``defaultdict(list)`` mapping a normalized issue id
        (``AIRFLOW-<number>``) to the commit dicts whose subject mentions it.
        Each commit dict has ``'merged'`` set to ``True`` and a
        ``'pull_request'`` entry.
    """
    merges = defaultdict(list)

    log_args = ['--format={}'.format(GIT_LOG_FORMAT)]
    if previous_version:
        log_args.append(previous_version + "..")
    log = repo.git.log(*log_args)
    if not log:
        # No commits in the range. Parsing an empty string would produce one
        # bogus record without a 'subject' key and fail with a KeyError below.
        return merges

    log = log.strip('\n\x1e').split("\x1e")
    log = [row.strip().split("\x1f") for row in log]
    log = [dict(zip(GIT_COMMIT_FIELDS, row)) for row in log]

    for log_item in log:
        _process_git_log_item(log_item)
        log_item['merged'] = True

        # A subject may reference several issues; file the commit under each.
        for issue_id in issue_re.findall(log_item['subject']):
            merges[normalize_issue_id(issue_id)].append(log_item)

    return merges


def get_commits_from_master(repo, issue):
    """Return the commits on ``origin/master`` whose subject references ``issue``.

    :param repo: repository object; ``repo.git.log`` is called on it.
    :param issue: Jira issue; only its ``key`` attribute is read.
    :return: list of commit dicts, each with ``'merged'`` set to ``False`` and
        a ``'pull_request'`` entry. Empty when git reports no matching commit.
    """
    raw_log = repo.git.log(
        '--format={}'.format(GIT_LOG_FORMAT),
        '--grep',
        r'.*{issue}\(\]\|:\|\s\)'.format(issue=issue.key),
        'origin/master')
    if not raw_log:
        return []

    commits = []
    for record in raw_log.strip('\n\x1e').split("\x1e"):
        commit = dict(zip(GIT_COMMIT_FIELDS, record.strip().split("\x1f")))
        _process_git_log_item(commit)
        commit['merged'] = False
        # --grep searches the whole commit message, so a hit may come from the
        # body; keep only commits that name the issue in their subject.
        if issue.key in issue_re.findall(commit['subject']):
            commits.append(commit)

    return commits


# Root command group of this tool; the sub-commands below attach themselves to
# it with ``@cli.command``. The body is intentionally just a docstring.
# NOTE: click displays this docstring as the group's --help text, so it is
# user-facing output rather than internal documentation.
@click.group()
def cli():
    r"""
    This tool should be used by Airflow Release Manager to verify what Jira's
     were merged in the current working branch.

        airflow-jira compare <target_version>
    """


@cli.command(short_help='Compare a jira target version against git merges')
@click.argument('target_version', default=None)
@click.option('--previous-version',
              'previous_version',
              help="Specify the previous tag on the working branch to limit"
                   " searching for few commits to find the cherry-picked commits")
@click.option('--unmerged', 'unmerged_only', help="Show unmerged issues only", is_flag=True)
def compare(target_version, previous_version=None, unmerged_only=False):
    # Print a table with one row per (Jira issue, commit) pair for every issue
    # whose fixVersion is target_version, followed by a summary line.
    #
    # target_version:   Jira fixVersion to report on.
    # previous_version: optional revision; when given, only commits after it
    #                   on the current branch count as already merged.
    # unmerged_only:    when set, rows for commits that are already on the
    #                   branch are left out (they still count in the summary).
    #
    # This is a comment rather than a docstring on purpose: click would show a
    # docstring as the command's --help text.
    repo = git.Repo(".", search_parent_directories=True)
    # Get a list of issues/PRs that have been committed on the current branch.
    branch_merges = get_merged_issues(repo, target_version, previous_version)

    issues = get_jiras_for_version(target_version)

    num_merged = 0
    num_unmerged = Counter()

    # :<18 says left align, pad to 18
    # :<50.50 truncates after 50 chars
    # !s forces as string - some of the Jira objects have a string method, but
    #    Py3 doesn't call by default
    formatstr = "{id:<18}|{typ!s:<12}||{priority!s:<10}||{status!s}|" \
                "{description:<83.83}|{merged:<6}|{pr:<6}|{commit:>9}"

    print(formatstr.format(
        id="ISSUE ID",
        typ="TYPE",
        priority="PRIORITY",
        status="STATUS".ljust(10),
        description="DESCRIPTION",
        merged="MERGED",
        pr="PR",
        commit="COMMIT"))

    for issue in issues:

        # Put colour on the status field. Since it will have non-printable
        # characters we can't use string format to limit the length
        status = issue.fields.status.name
        if status in STATUS_COLUR_MAP:
            status = click.style(status[:10].ljust(10), STATUS_COLUR_MAP[status])
        else:
            status = status[:10].ljust(10)

        # Find the merges in master targeting this issue
        master_merges = get_commits_from_master(repo, issue)

        # Commits for this issue that are already on the branch, keyed by PR
        # so they can be paired with the corresponding commit on master.
        on_branch = {
            m['pull_request']: m
            for m in branch_merges.get(issue.key, [])
        }

        def print_merge_info(merge, printed_desc):
            """Count ``merge`` and print its row unless it is filtered out.

            :param merge: commit dict with 'merged', 'pull_request' and 'id'.
            :param printed_desc: whether the issue details were already
                printed on an earlier row of this issue.
            :return: the new value of ``printed_desc``.
            """
            nonlocal num_merged

            is_merged = merge['merged']

            if is_merged:
                num_merged += 1
                if unmerged_only:
                    # Nothing is printed for this commit, so the state must be
                    # passed through unchanged. Returning False here would make
                    # the issue details print again on a later row.
                    return printed_desc
            else:
                num_unmerged[issue.fields.status.name] += 1

            if not printed_desc:
                # Only print info on first line for each issue
                fields = dict(
                    id=issue.key,
                    typ=issue.fields.issuetype,
                    priority=issue.fields.priority,
                    status=status,
                    description=issue.fields.summary,
                )
            else:
                fields = dict(
                    id=issue.key,
                    typ="",
                    priority="",
                    status=" " * 10,
                    description="",
                )
            print(formatstr.format(
                **fields,
                merged=is_merged,
                pr=merge['pull_request'],
                commit=merge['id']))
            return True

        printed_desc = False
        for merge in master_merges:
            # Prefer the branch's own commit when the same PR is already there.
            if merge['pull_request'] in on_branch:
                merge = on_branch[merge['pull_request']]

            printed_desc = print_merge_info(merge, printed_desc)

        if not master_merges:
            if on_branch:
                for merge in branch_merges.get(issue.key):
                    printed_desc = print_merge_info(merge, printed_desc)
            else:
                # No merges, issue likely still open
                print_merge_info({
                    'merged': 0,
                    'pull_request': '-',
                    'id': '-',
                }, printed_desc)

    print("Commits on branch: {0:d}, {1:d} ({2}) yet to be cherry-picked".format(
        num_merged, sum(num_unmerged.values()), dict(num_unmerged)))


@cli.command(short_help='Build a CHANGELOG grouped by Jira Issue type')
@click.argument('previous_version')
@click.argument('target_version')
def changelog(previous_version, target_version):
    # Print the subjects of the commits in previous_version..target_version,
    # grouped under the issue type of the first Jira issue each subject
    # references. Commits without an issue reference go under 'uncategorized'.
    #
    # This is a comment rather than a docstring on purpose: click would show a
    # docstring as the command's --help text.
    repo = git.Repo(".", search_parent_directories=True)
    # Get the commits between the two versions.
    log_args = ['--format={}'.format(GIT_LOG_FORMAT), previous_version + ".." + target_version]
    log = repo.git.log(*log_args)
    if not log:
        # Empty range: nothing to report. Parsing an empty string would yield
        # a bogus record without a 'subject' key and fail with a KeyError.
        return

    log = log.strip('\n\x1e').split("\x1e")
    log = [row.strip().split("\x1f") for row in log]
    log = [dict(zip(GIT_COMMIT_FIELDS, row)) for row in log]

    asf_jira = jira.client.JIRA({'server': JIRA_API_BASE})

    sections = defaultdict(list)

    # The Jira API is kinda slow, so ask for this many issues at a time. The
    # same number caps the search results, so one request covers one batch.
    batch_size = 50
    batch = []

    def process_batch(batch):
        """Look up the issue type of every (key, subject) pair in ``batch``
        with a single Jira search and file each subject under that type."""
        results = {i.key: i for i in asf_jira.search_issues(
            'PROJECT={} AND key IN ({})'.format(
                PROJECT,
                ','.join(key for key, _ in batch)),
            maxResults=batch_size,
        )}

        for key, subject in batch:
            sections[results[key].fields.issuetype.name].append(subject)

    for commit in log:
        tickets = issue_re.findall(commit['subject'])

        if not tickets or 'AIRFLOW-XXX' in tickets:
            # TODO: Guess by files changed?
            sections['uncategorized'].append(commit['subject'])

        else:
            batch.append((normalize_issue_id(tickets[0]), commit['subject']))

            if len(batch) == batch_size:
                process_batch(batch)
                batch = []

    # Flush the last, partially filled batch; without this the commits in it
    # would be missing from the changelog.
    if batch:
        process_batch(batch)

    for section, lines in sections.items():
        print(section)
        print('"' * len(section))
        for line in lines:
            print('-', line)
        print()


if __name__ == "__main__":
    import doctest

    # Run the module's doctests first and refuse to start if any of them fail.
    failure_count, _ = doctest.testmod()
    if failure_count:
        # sys.exit rather than the bare `exit` helper, which is only available
        # when the `site` module has been loaded.
        sys.exit(-1)

    cli()
